from ast import arg
import json
import logging
import math
from operator import mod
import os
from pprint import pformat

import datasets
import numpy as np
import torch
import transformers
from accelerate import Accelerator
from fastcore.all import *
from tqdm.auto import tqdm
from transformers import AdamW, AutoTokenizer, BertTokenizer, get_scheduler, set_seed

from models_my import get_auto_model, GPLinker
from torch.optim.lr_scheduler import LambdaLR
from utils.args import parse_args
from utils.data import get_dataloader_and_dataset
from utils.data_duie import get_data_loader
from utils.postprocess import postprocess_gplinker, postprocess_tplinker_plus
from utils.utils import get_writer, try_remove_old_ckpt, write_json
# Module-level logger; main() attaches the run.log file handler and sets its level.
logger = logging.getLogger(name=__name__)


def get_linear_warmup(optizer, nums_warmup_step, nums_train_step, last_epoch=-1):
    """Build a linear-warmup, linear-decay learning-rate schedule.

    The LR multiplier rises linearly from 0 to 1 over the first
    ``nums_warmup_step`` steps. It then falls linearly to 0 at
    ``nums_train_step`` and stays at 0 for any later step.

    Args:
        optizer: Optimizer whose parameter groups are scheduled. The name is
            kept as-is for compatibility with keyword callers.
        nums_warmup_step: Number of warmup steps. May be a float, e.g. a ratio
            already multiplied by the total step count.
        nums_train_step: Total number of training steps.
        last_epoch: Index of the last completed step when resuming; -1 starts
            a fresh schedule.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` driving ``optizer``.
    """

    def lr_lambda(cur_step):
        if cur_step < nums_warmup_step:
            # max(1, ...) keeps the divisor strictly positive; max(0, ...)
            # would not guard against a zero divisor at all.
            return cur_step / max(1, nums_warmup_step)
        remaining = max(0, nums_train_step - cur_step)
        return remaining / max(1, nums_train_step - nums_warmup_step)

    return LambdaLR(optizer, lr_lambda, last_epoch=last_epoch)

@torch.no_grad()
def evaluate(
    args,
    model,
    dev_dataloader,
    accelerator,
    global_steps=0,
    threshold=0,
    write_predictions=True,
):
    """Run ``model`` over the dev set and score predicted SPO triples.

    Model outputs are gathered with ``accelerator.gather`` and decoded by the
    post-processor selected by ``args.method``. Each example's predictions are
    compared, as a set of tuples, with the gold ``spo_list`` in
    ``dev_dataloader.dataset.raw_data``.

    Args:
        args: Run configuration; ``method`` and ``output_dir`` are read here
            (and ``args`` is forwarded to the post-processor).
        model: Model to evaluate. It is put in eval mode for the forward passes
            and switched back to train mode afterwards, even on error.
        dev_dataloader: Yields dict batches holding ``offset_mapping`` and
            ``text`` alongside the model inputs. Its ``dataset.raw_data`` must
            expose ``"spo_list"`` and ``"text"``.
        accelerator: ``accelerate.Accelerator`` used to gather model outputs.
        global_steps: Training step; used only in the predictions file name.
        threshold: Decoding threshold forwarded to ``postprocess_gplinker``.
        write_predictions: If True, write one JSON line per dev example to
            ``<output_dir>/preds/<global_steps>_step_preds_<method>.json``.

    Returns:
        dict with ``f1``, ``precision``, ``recall`` and the epsilon-smoothed
        counts ``tp`` (matched), ``tpfp`` (predicted) and ``tpfn`` (gold).

    Raises:
        ValueError: if ``args.method`` is not ``"gplinker"`` or ``"tplinker_plus"``.
    """
    # Fail fast, before any forward pass and before touching the model's mode.
    if args.method not in ("gplinker", "tplinker_plus"):
        raise ValueError(
            "args.method should be chosen from ['gplinker', 'tplinker_plus']!"
        )

    model.eval()
    try:
        # Send inputs to wherever the model lives instead of hard-coding
        # "cuda"; main() falls back to CPU when CUDA is unavailable.
        device = next(model.parameters()).device
        all_predictions = []
        bar_eval = tqdm(range(len(dev_dataloader)), desc="Evaluating")
        for batch in dev_dataloader:
            # Decoding metadata; these are not model inputs.
            offset_mappings = batch.pop("offset_mapping")
            texts = batch.pop("text")
            batch = {
                k: v.to(device) if torch.is_tensor(v) else v
                for k, v in batch.items()
            }
            outputs = accelerator.gather(model(**batch)[0])
            if args.method == "gplinker":
                decoded = postprocess_gplinker(
                    args, outputs, offset_mappings, texts, threshold
                )
            else:  # "tplinker_plus" (validated above)
                decoded = postprocess_tplinker_plus(
                    args,
                    outputs,
                    offset_mappings,
                    texts,
                    batch["input_ids"].size(1),
                )
            all_predictions.extend(decoded)
            bar_eval.update(1)
        bar_eval.close()
    finally:
        model.train()

    raw = dev_dataloader.dataset.raw_data
    # Counts start at a tiny epsilon so the ratios below never divide by zero.
    tp, n_pred, n_gold = 1e-10, 1e-10, 1e-10
    pred_lines = []
    # NOTE(review): zip stops at the shorter sequence and pairs by position, so
    # this assumes the dev loader yields examples in raw_data order — confirm.
    for preds, golds, text in zip(all_predictions, raw["spo_list"], raw["text"]):
        pred_set = set(preds)
        gold_set = set(tuple(g) for g in golds)
        tp += len(pred_set & gold_set)
        n_pred += len(pred_set)
        n_gold += len(gold_set)
        if write_predictions:
            record = {
                "text": text,
                "spo_list": list(gold_set),
                "spo_list_pred": list(pred_set),
                "new": list(pred_set - gold_set),
                "lack": list(gold_set - pred_set),
            }
            pred_lines.append(json.dumps(record, ensure_ascii=False) + "\n")

    if write_predictions:
        pred_dir = os.path.join(args.output_dir, "preds")
        os.makedirs(pred_dir, exist_ok=True)
        pred_file = os.path.join(
            pred_dir, f"{global_steps}_step_preds_{args.method}.json"
        )
        # `with` guarantees the file is closed even if writing fails.
        with open(pred_file, "w", encoding="utf-8") as f:
            f.writelines(pred_lines)

    f1 = 2 * tp / (n_pred + n_gold)
    precision = tp / n_pred
    recall = tp / n_gold
    return {
        "f1": f1,
        "precision": precision,
        "recall": recall,
        "tp": tp,
        "tpfp": n_pred,
        "tpfn": n_gold,
    }


def _setup_logging(args, accelerator):
    """Log this run to ``<args.output_dir>/run.log`` and set per-process verbosity.

    The log file is overwritten each run. The local main process logs at INFO;
    all other processes are limited to ERROR.
    """
    # FileHandler raises FileNotFoundError when the target directory is missing.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
        handlers=[
            logging.FileHandler(
                os.path.join(args.output_dir, "run.log"),
                mode="w",
                encoding="utf-8",
            )
        ],
    )
    logger.info(accelerator.state)
    is_main = accelerator.is_local_main_process
    logger.setLevel(logging.INFO if is_main else logging.ERROR)
    if is_main:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()


def _load_predicate_schema(path):
    """Index the distinct ``predicate`` values of a JSON-lines schema file.

    Returns:
        ``(predicate2id, id2predicate, predicates)``. Ids are assigned in order
        of first appearance, and ``predicates`` lists them in that same order.
    """
    predicate2id, id2predicate, predicates = {}, {}, []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            predicate = json.loads(line)["predicate"]
            if predicate not in predicate2id:
                idx = len(predicate2id)
                predicate2id[predicate] = idx
                id2predicate[idx] = predicate
                predicates.append(predicate)
    return predicate2id, id2predicate, predicates


def _build_optimizer(model, args):
    """Create AdamW with weight decay on every parameter except those whose
    name contains ``bias``, ``LayerNorm.weight`` or ``norm``."""
    no_decay = ("bias", "LayerNorm.weight", "norm")
    decay_params, no_decay_params = [], []
    for name, param in model.named_parameters():
        group = no_decay_params if any(nd in name for nd in no_decay) else decay_params
        group.append(param)
    return torch.optim.AdamW(
        [
            {"params": decay_params, "weight_decay": args.weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        ],
        lr=args.learning_rate,
        eps=args.adam_epsilon,
    )


def _log_dev_metrics(dev_metric, global_steps, writer, accelerator):
    """Print, log, and record to ``writer`` every dev metric for this step."""

    def emit(msg):
        accelerator.print(msg)
        logger.info(msg)

    emit("##--------------------- Dev")
    emit("-" * 80)
    emit(f"global_steps = {global_steps}")
    for k, v in dev_metric.items():
        writer.add_scalar(k, v, global_steps)
        emit(f"{k} = {v}")
    emit("-" * 80)
    emit("**--------------------- Dev End")


def main():
    """Train GPLinker on DuIE and evaluate on the dev set every ``save_steps``.

    Predicates are read from ``data/all_50_schemas``. The train loader is
    built from ``./resources/datasets/duie/train_data.json`` and the dev loader
    via ``get_dataloader_and_dataset``. Training is a plain PyTorch loop:
    AdamW, linear warmup/decay, and gradient clipping.
    ``args.num_warmup_steps_or_radios`` is a fraction of the total steps when
    it is a float, and an absolute step count otherwise.
    """
    args = parse_args()
    accelerator = Accelerator()
    _setup_logging(args, accelerator)

    if args.seed is not None:
        set_seed(args.seed)

    predicate2id, id2predicate, predicates = _load_predicate_schema(
        "data/all_50_schemas"
    )
    args.predicate2id = predicate2id
    args.id2predicate = id2predicate
    args.num_labels = len(id2predicate)

    # NOTE(review): the model built here is always GPLinker; running with
    # args.method == "tplinker_plus" presumably needs tag2id/id2tag setup that
    # this script does not perform — confirm before using that method.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = GPLinker(args.pretrained_model_name_or_path, predicate2id)
    tokenizer = BertTokenizer.from_pretrained(args.pretrained_model_name_or_path)

    train_dataloader = get_data_loader(
        ["./resources/datasets/duie/train_data.json"],
        args.pretrained_model_name_or_path,
        args.per_device_train_batch_size,
        args.max_length,
        "train",
        predicates,
        "handle",
        "features",
        nums_save=2000,
        features_num=2000,
        do_lower_case=False,
    )
    _, dev_dataloader = get_dataloader_and_dataset(
        args,
        tokenizer,
        predicate2id=predicate2id,
        use_fp16=accelerator.use_fp16,
        text_column_name="text",
        label_column_name="spo_list",
    )

    opt = _build_optimizer(model, args)
    steps_total = len(train_dataloader) * args.num_train_epochs
    warmup = args.num_warmup_steps_or_radios
    # A float is a ratio of the total steps; an int is already a step count.
    num_warmup_steps = steps_total * warmup if isinstance(warmup, float) else warmup
    lr_schedule = get_linear_warmup(opt, num_warmup_steps, steps_total)
    writer = get_writer(args)

    model.to(device)
    model.train()
    bar_train = tqdm(range(steps_total))
    loss_tl, loss_pre = 0.0, 0.0  # running loss total / total at last log
    global_steps = 0
    # TODO: checkpoint saving and best-F1 tracking are not implemented here.
    for _epoch in range(args.num_train_epochs):
        for batch in train_dataloader:
            opt.zero_grad()
            global_steps += 1
            # The loader yields positional tensors; the second one is not fed
            # to the model (None is passed in its place).
            input_ids, _, entity, head, tail = [v.to(device) for v in batch[:5]]
            loss = model(input_ids, None, [entity, head, tail])[0]
            loss_tl += loss.item()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm=args.max_grad_norm
            )
            opt.step()
            lr_schedule.step()
            bar_train.update(1)

            # `> 0` guards avoid ZeroDivisionError when logging/eval is disabled.
            if args.logging_steps > 0 and global_steps % args.logging_steps == 0:
                loss_avg = (loss_tl - loss_pre) / args.logging_steps
                loss_pre = loss_tl
                lr = lr_schedule.get_last_lr()[-1]
                msg = f"global_steps {global_steps} - lr: {lr:.8f} - loss: {loss_avg:.4f}"
                bar_train.set_description(msg)
                logger.info(msg)
                writer.add_scalar("loss", loss_avg, global_steps)
                writer.add_scalar("lr", lr, global_steps)

            if args.save_steps > 0 and global_steps % args.save_steps == 0:
                dev_metric = evaluate(
                    args, model, dev_dataloader, accelerator, global_steps, 0, True
                )
                _log_dev_metrics(dev_metric, global_steps, writer, accelerator)
    bar_train.close()


if __name__ == "__main__":
    # Script entry point: parse CLI args, train, and periodically evaluate.
    raise SystemExit(main())
